from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import openai
from collections import defaultdict
from datetime import datetime

class OpenBookQAReasoner(BaseReasoner):
    """Constraint-guided reasoner for OpenBookQA multiple-choice questions.

    This variant is an ablation: path exploration is skipped and a single
    default "direct reasoning" method node is solved (optionally split into
    case nodes by an LLM classification step). Leaf answers are then
    aggregated into a single option letter A-D, or ``"X"`` when no answer
    could be extracted.
    """

    # Persona preamble shared verbatim by every prompt in this class.
    _EXPERT_PREAMBLE = (
        "You are a top expert in formal logic, critical thinking, and argument analysis.  \n"
        "You are precise, rational, and skeptical.  \n"
        "You always examine each statement carefully, identify premises and conclusions, and evaluate logical validity step by step.  \n"
        "You avoid unwarranted assumptions, think in terms of logical consequences, and eliminate invalid options with sound reasoning.  \n"
        "You aim to reach conclusions based only on evidence and logic.  \n"
        "You THINK SLOWLY, CAREFULLY, AND LOGICALLY.\n"
    )

    # Answer-extraction patterns, ordered from most to least reliable.
    # Optional brackets/parentheses are tolerated because the prompts show the
    # placeholder as "[OPTION]" and models may echo the brackets.
    _BOXED_RE = re.compile(r"\\boxed\{\s*[\[(]?\s*([A-D])\s*[\])]?\s*\}")
    # Only the word "answer" is case-insensitive; the letter must be an
    # uppercase A-D followed by a word boundary, so that e.g.
    # "Answer: because ..." is not read as option B.
    _ANSWER_RE = re.compile(r"(?i:answer)\s*:\s*\**\s*[\[(]?\s*([A-D])\b")
    # Last-resort heuristic: a standalone capital A-D not followed by ".<word>".
    _LETTER_RE = re.compile(r"\b([A-D])\b(?!\.\w)")

    def __init__(self):
        super().__init__("OpenBookQA")
        self.config.dataset_path = "datasets/OpenBookQA.json"

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load OpenBookQA problems from dataset.

        Args:
            start_idx: Start index (inclusive) into the dataset list.
            end_idx: End index (exclusive) into the dataset list.

        Returns:
            The slice ``data[start_idx:end_idx]`` of the JSON file at
            ``self.config.dataset_path``; ``[]`` (after printing the error)
            if the file cannot be read, parsed, or sliced.
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def execute_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Run the (ablated) reasoning workflow for one OpenBookQA problem.

        Pipeline: root node -> constraint extraction -> a single default
        method node (path exploration, original steps 3-4, is skipped) ->
        optional case classification -> solve each leaf node -> aggregate.

        Args:
            problem: Dataset record; reads ``problem["question_stem"]`` and
                the first four entries of ``problem["choices"]["text"]``,
                which become options A-D.

        Returns:
            On success ``{"status": "success", "final_answer", "nodes",
            "logs", "token_usage"}``; on any exception
            ``{"status": "error", "message", "logs"}``.
        """
        try:
            question = problem["question_stem"]
            choice_texts = problem["choices"]["text"]
            # Every prompt assumes exactly options A-D; fewer than four
            # choices raises IndexError, reported as an error result below.
            options = {label: choice_texts[i] for i, label in enumerate("ABCD")}

            # Step 1: Create root node
            root = self._create_node(
                question=question,
                options=options,
                constraints={},
                path=[],
                method={"description": "Original problem"}
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step 2: Extract constraints
            constraints = await self._extract_constraints(question, options)
            root.constraints = constraints
            self._log_step("step2", root.node_id, {"constraints": constraints})

            # Ablation: skip path exploration (original steps 3 and 4) and
            # create a single method node with a fixed default method.
            default_method = {
                "description": "Direct reasoning with constraints",
                "steps": [
                    "Analyze question and constraints",
                    "Evaluate each option against constraints",
                    "Select best matching option"
                ],
                "score": 80,
                "score_reason": "Default method for ablation study"
            }

            method_node = self._create_node(
                path=[root.node_id],
                question=question,
                options=options,
                method=default_method,
                constraints=root.constraints,
                score=default_method["score"],
                parent_id=root.node_id
            )
            root.children.append(method_node.node_id)
            self._log_step("step4_ablation", method_node.node_id, {"method": default_method})

            # Step 5: Check classification for the method
            classification = await self._check_classification(
                method_node.method["description"],
                question,
                options
            )
            self._log_step("step5", method_node.node_id, {"classification": classification})

            # NOTE(review): self.temp_list is appended to but never cleared
            # here; this assumes the base class resets it between problems.
            case_nodes_created = 0
            if classification["need_classify"]:
                # Step 6: Create one classification node per case
                for case in classification["cases"]:
                    combined_constraints = self._merge_case_constraints(
                        method_node.constraints, case.get("constraints")
                    )
                    node = self._create_node(
                        path=method_node.path + [method_node.node_id],
                        question=question,
                        options=options,
                        method=method_node.method,
                        constraints=combined_constraints,
                        score=method_node.score,
                        parent_id=method_node.node_id
                    )
                    method_node.children.append(node.node_id)
                    self.temp_list.append(node.node_id)
                    case_nodes_created += 1
                    self._log_step("step6", node.node_id, {"case": case})

            if case_nodes_created == 0:
                # No classification requested, or the model asked for it but
                # supplied no usable cases: solve the method node directly so
                # the problem always has at least one candidate.
                self.temp_list.append(method_node.node_id)

            # Step 7: Solve nodes
            solutions = []
            for node_id in self.temp_list:
                solution = await self._solve_node(node_id)
                if solution:
                    solutions.append(solution)
                    self._log_step("step7", node_id, {"solution": solution})

            # Step 8: Aggregate answers
            final_answer = await self._aggregate_answers(solutions)
            self._log_step("step8", "system", {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
                "token_usage": self.llm.token_counts
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Coerce an LLM-provided value to a new list (None -> [], scalar -> [scalar])."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _format_options(options: Dict[str, str]) -> str:
        """Render options as "A. text" lines joined by newlines (no trailing newline)."""
        return "\n".join(f"{label}. {text}" for label, text in options.items())

    @classmethod
    def _merge_case_constraints(cls, base: Any, case_constraints: Any) -> Dict[str, List[Any]]:
        """Combine a node's explicit/implicit constraints with one case's constraints.

        Case keys named "explicit"/"implicit" extend those lists; any other
        key is recorded as an implicit "key: value" string. A non-dict case
        payload is treated as a list of implicit constraints. The inputs are
        never mutated.
        """
        base = base if isinstance(base, dict) else {}
        merged = {
            "explicit": cls._as_list(base.get("explicit")),
            "implicit": cls._as_list(base.get("implicit")),
        }
        if isinstance(case_constraints, dict):
            for key, value in case_constraints.items():
                if key in merged:
                    merged[key].extend(cls._as_list(value))
                else:
                    merged["implicit"].append(f"{key}: {value}")
        else:
            merged["implicit"].extend(cls._as_list(case_constraints))
        return merged

    async def _extract_constraints(self, question: str, options: Dict[str, str]) -> Dict[str, Any]:
        """Ask the LLM for the problem's constraints as a JSON object.

        Retries up to ``self.config.max_retries`` times when the call fails,
        the response is not valid JSON, or the JSON is not an object.

        Returns:
            The parsed dict, or a fallback dict with empty "explicit",
            "implicit", and "key_terms" lists if every attempt fails.
        """
        prompt = f"""{self._EXPERT_PREAMBLE}Analyze this question and extract key constraints:

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Identify:
1. Explicit constraints (directly stated)
2. Implicit constraints (logical implications)
3. Key terms and their relationships
4. Spatial/temporal relationships if present
5. Any conditional statements

Output JSON format:
{{
    "explicit": ["list", "of", "constraints"],
    "implicit": ["list", "of", "constraints"],
    "key_terms": ["term1", "term2"],
    "notes": "Analysis summary"
}}"""

        for _attempt in range(self.config.max_retries):
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                data = json.loads(response)
            except Exception:
                # Not a bare `except:` -- asyncio.CancelledError (a
                # BaseException on Python 3.8+) and KeyboardInterrupt must
                # still propagate.
                continue
            if isinstance(data, dict):
                return data

        return {
            "explicit": [],
            "implicit": [],
            "key_terms": [],
            "notes": "Failed to extract constraints"
        }

    async def _check_classification(self, method: str, question: str, options: Dict[str, str]) -> Dict[str, Any]:
        """Step 5: ask the LLM whether the approach needs case classification.

        The returned dict always has a boolean "need_classify" and a
        "cases" list containing only dict entries, so callers can index
        these keys safely. A failed call, invalid JSON, or a non-object
        response yields ``{"need_classify": False, "reason":
        "Analysis failed", "cases": []}``.
        """
        options_text = self._format_options(options)

        prompt = f"""{self._EXPERT_PREAMBLE}Determine if this solution approach requires case classification:

Solution Approach: {method}
Question: {question}
Options:
{options_text}

Consider:
1. Does the question contain multiple scenarios or cases?
2. Are there conditional statements that create distinct possibilities?
3. Do the options represent different logical paths?
4. Would different initial assumptions lead to different solutions?

If classification needed, provide:
- Comprehensive case descriptions
- Precise conditions for each case
- Expected outcomes

Output JSON format:
{{
    "need_classify": true/false,
    "reason": "Classification rationale",
    "cases": [
        {{
            "description": "Case description",
            "constraints": {{"parameter": "value_range"}}
        }}
    ]
}}"""

        fallback = {
            "need_classify": False,
            "reason": "Analysis failed",
            "cases": []
        }
        try:
            response = await self.llm.generate(prompt, response_format="json_object")
            data = json.loads(response)
        except Exception:
            return fallback
        if not isinstance(data, dict):
            return fallback

        # Models sometimes emit "true"/"false" as strings; a non-empty string
        # would otherwise be truthy.
        need = data.get("need_classify", False)
        if isinstance(need, str):
            need = need.strip().lower() == "true"
        data["need_classify"] = bool(need)

        cases = data.get("cases")
        data["cases"] = [c for c in cases if isinstance(c, dict)] if isinstance(cases, list) else []
        return data

    async def _solve_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Step 7: solve one reasoning node with the LLM.

        On a parseable answer, sets ``node.answer`` and
        ``node.state = "solved"`` and returns ``{"node_id", "response",
        "answer"}``; otherwise returns None and leaves the node unchanged.
        """
        node = self.nodes[node_id]
        method = node.method or {}
        description = method.get("description", "")
        steps = method.get("steps") or []

        # Build context prompt
        context = (
            f"Question: {node.question}\nOptions:\n"
            f"{self._format_options(node.options)}\n"
            f"\nSolution Approach: {description}\n"
            f"Constraints: {json.dumps(node.constraints, indent=2)}\n"
        )

        # `\\boxed` is escaped on purpose: in a non-raw string `\b` is a
        # backspace character, which would corrupt the instruction.
        prompt = f"""{self._EXPERT_PREAMBLE}Solve this question using the specified approach:

{context}

Reasoning Steps:
1. Strictly follow the provided approach: {description}
2. Execute each step: {', '.join(steps)}
3. Consider all constraints
4. Evaluate each option systematically
5. Provide clear justification for inclusion/exclusion
6. Select the best answer

Output Requirements:
- End your response with: "Final Answer: [OPTION]"
- Use \\boxed{{[OPTION]}} to denote your answer
- Your answer must be A, B, C, or D
"""

        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)

        if answer:
            node.answer = answer
            node.state = "solved"
            return {
                "node_id": node_id,
                "response": response,
                "answer": answer
            }
        return None

    async def _aggregate_answers(self, solutions: List[Dict[str, Any]]) -> str:
        """Step 8: combine per-node answers into one final option letter.

        Returns "X" when there are no solutions. Returns the shared answer
        without an LLM call when there is one solution or all agree.
        Otherwise asks the LLM to arbitrate. If its reply contains no
        parseable answer, falls back to a majority vote (ties go to the
        answer seen first).
        """
        if not solutions:
            return "X"  # Invalid answer

        answers = [s["answer"] for s in solutions]
        if len(set(answers)) == 1:
            return answers[0]

        # All solution nodes share the problem's question/options; include
        # them so the arbiter can check the answers against the question.
        first_node = self.nodes[solutions[0]["node_id"]]
        options_text = self._format_options(first_node.options)

        solutions_text = ""
        for i, sol in enumerate(solutions):
            node = self.nodes[sol["node_id"]]
            solutions_text += f"\n\nSolution {i+1} (Node {sol['node_id']}):"
            solutions_text += f"\nApproach: {node.method['description']}"
            solutions_text += f"\nConstraints: {json.dumps(node.constraints, indent=2)}"
            solutions_text += f"\nAnswer: {sol['answer']}"
            solutions_text += f"\nReasoning Excerpt:\n{sol['response']}..."

        prompt = f"""{self._EXPERT_PREAMBLE}Synthesize these approaches:

Question: {first_node.question}
Options:
{options_text}
{solutions_text}

Instructions:
1. Analyze all solutions and their approaches
2. Identify the most reliable reasoning
3. Verify consistency with constraints
4. Select the best overall answer
5. Output format: \\boxed{{[ANSWER]}}
"""

        response = await self.llm.generate(prompt)
        return self._extract_answer(response) or max(answers, key=answers.count)

    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        """Build a serializable record of a workflow result and its verification.

        Args:
            result: Return value of ``execute_workflow``; a missing or empty
                "final_answer" is recorded as "X".
            problem: Dataset record; its "answerKey" is the reference answer.

        Returns:
            ``{"problem", "result", "verification"}``, where
            ``result["nodes"]`` contains a plain-dict snapshot of each node.
        """
        # Convert nodes to serializable format
        serialized_nodes = {
            node_id: {
                "node_id": node.node_id,
                "question": node.question,
                "options": node.options,
                "method": node.method,
                "constraints": node.constraints,
                "answer": node.answer,
                "state": node.state,
                "score": node.score
            }
            for node_id, node in self.nodes.items()
        }

        # Prepare verification
        selected_answer = result.get("final_answer") or "X"
        correct_answer = (problem.get("answerKey") or "").strip().upper()
        is_correct = self.verify_answer(problem, selected_answer)
        verification = {
            "is_correct": is_correct,
            "correct_answer": correct_answer,
            "given_answer": selected_answer
        }
        return {
            "problem": problem,
            "result": {
                "final_answer": selected_answer,
                "correct_answer": correct_answer,
                "is_correct": is_correct,
                "nodes": serialized_nodes,
                "token_usage": result.get("token_usage", [0, 0])
            },
            "verification": verification
        }

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract an option letter (A-D) from an LLM response.

        Tries, in order: ``\\boxed{X}``, then "Answer: X", then any
        standalone capital A-D. For each pattern the LAST match is used,
        because the concluding statement normally comes at the end of the
        response.

        Returns:
            The letter, or None if ``text`` is empty, not a string, or
            contains no match.
        """
        if not isinstance(text, str) or not text:
            return None
        for pattern in (self._BOXED_RE, self._ANSWER_RE, self._LETTER_RE):
            matches = pattern.findall(text)
            if matches:
                return matches[-1]
        return None

    def verify_answer(self, problem: Dict[str, Any], selected_answer: str) -> bool:
        """Return True if ``selected_answer`` matches the problem's "answerKey".

        The comparison ignores case and surrounding whitespace. It is False
        when either the answer or the key is missing or empty.
        """
        correct_answer = (problem.get("answerKey") or "").strip().upper()
        if not correct_answer or not selected_answer:
            return False
        return selected_answer.strip().upper() == correct_answer